{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "not_data = np.array(pd.read_excel('noreply.xlsx',sheet_name='Sheet1',usecols=[0])).flatten()\n",
    "need_data = np.array(pd.read_excel('needreply2.xlsx',usecols=[1])).flatten()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "not_data 14027 ['2018住房转移记录是什么意思' '一个月贵多少来着' '那么高的价格？你是2层的，那单层的什么价格' ... '拉微信和支付宝吗'\n",
      " '没去看过绿城万科样板房，绿万哪些地方好些啊' '地暖多少钱啊']\n",
      "need_data 41756 ['销售，有吗？销售联系方式' '销售，有吗？销售联系方式' '销售，有吗？销售联系方式' ... '倒挂，多少？' '倒挂，多少？'\n",
      " '倒挂，多少？']\n"
     ]
    }
   ],
   "source": [
    "print('not_data', len(not_data),not_data)\n",
    "print('need_data', len(need_data),need_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "not_data ['幼子一辈子只能当接盘侠 一辈子只能给巨子打工？' '中海有什么附加协议吗' '这么牛逼，以前没托关系给你弄套样板房吗？'\n",
      " '笕桥机场都是什么军机？有没有j20？' '小鸡选房二手房怎么查' '不是说早就爆了吗' '请问为啥下架了？'\n",
      " '很好奇这些人才的收入，不知道前面有展示么' '想冲什么随意' '啊 拿什么卡贷款？？？']\n",
      "need_data ['开发商，是谁？' '需要摇号，与否？' '交付时间，什么时候？' '样板房参观，开放参观吗？' '开盘楼栋，那几栋？'\n",
      " '销售，有吗？销售联系方式' '房源优惠，有吗？' '不利因素' '车位绑定，是否绑定销售？' '周边规划，是什么？']\n"
     ]
    }
   ],
   "source": [
    "np.random.shuffle(not_data)\n",
    "np.random.shuffle(need_data)\n",
    "print('not_data', not_data[0:10])\n",
    "print('need_data', need_data[0:10])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "pre_train = []\n",
    "pre_target = []\n",
    "for item in not_data:\n",
    "    pre_train.append(str(item))\n",
    "    pre_target.append('1')\n",
    "for item in need_data:\n",
    "    pre_train.append(str(item))\n",
    "    pre_target.append('0')\n",
    "not_data = []\n",
    "need_data = []"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "len x_train 44626\n",
      "len x_text 11157\n"
     ]
    }
   ],
   "source": [
    "from sklearn.model_selection import train_test_split\n",
    "x_train,x_test,train_label,test_label = train_test_split(pre_train,pre_target,test_size=0.2,random_state=5)\n",
    "print('len x_train', len(x_train))\n",
    "print('len x_text', len(x_test))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "d:\\conda\\envs\\newbert\\lib\\site-packages\\tqdm\\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "from torch.utils.data import Dataset, DataLoader, TensorDataset\n",
    "from transformers import BertTokenizer\n",
    "tokenizer = BertTokenizer.from_pretrained('./bert-base-chinese')\n",
    "train_encoding = tokenizer(x_train, truncation=True, padding=True, max_length=128)\n",
    "test_encoding = tokenizer(x_test, truncation=True, padding=True, max_length=128)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 数据集读取\n",
    "class NewsDataset(Dataset):\n",
    "    def __init__(self, encodings, labels):\n",
    "        self.encodings = encodings\n",
    "        self.labels = labels\n",
    "    \n",
    "    # 读取单个样本\n",
    "    def __getitem__(self, idx):\n",
    "        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}\n",
    "        item['labels'] = torch.tensor(int(self.labels[idx]))\n",
    "        return item\n",
    "    \n",
    "    def __len__(self):\n",
    "        return len(self.labels)\n",
    "\n",
    "train_dataset = NewsDataset(train_encoding, train_label)\n",
    "test_dataset = NewsDataset(test_encoding, test_label)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at bert-base-chinese were not used when initializing BertForSequenceClassification: ['cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight']\n",
      "- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
      "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-chinese and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cuda\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "d:\\conda\\envs\\newbert\\lib\\site-packages\\transformers\\optimization.py:306: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n",
      "  warnings.warn(\n"
     ]
    }
   ],
   "source": [
    "from transformers import BertForSequenceClassification, AdamW, get_linear_schedule_with_warmup\n",
    "model = BertForSequenceClassification.from_pretrained('bert-base-chinese', num_labels=2)\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "model.to(device)\n",
    "print(device)\n",
    "# 单个读取到批量读取\n",
    "train_loader = DataLoader(train_dataset, batch_size=64)\n",
    "test_dataloader = DataLoader(test_dataset, batch_size=64)\n",
    "\n",
    "# 优化方法\n",
    "epoches = 2\n",
    "no_decay = ['bias', 'LayerNorm.weight']\n",
    "optimizer_grouped_parameters = [\n",
    "    # Filter for all parameters which *don't* include 'bias', 'gamma', 'beta'.\n",
    "    {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],\n",
    "     'weight_decay_rate': 0.1},\n",
    "    \n",
    "    # Filter for parameters which *do* include those.\n",
    "    {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],\n",
    "     'weight_decay_rate': 0.0}\n",
    "]\n",
    "optim = AdamW(optimizer_grouped_parameters, lr=2e-5)\n",
    "total_steps = len(train_loader) * epoches\n",
    "scheduler = get_linear_schedule_with_warmup(optim, \n",
    "                                            num_warmup_steps = 0, # Default value in run_glue.py\n",
    "                                            num_training_steps = total_steps)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The BERT model has 201 different named parameters.\n",
      "\n",
      "==== Embedding Layer ====\n",
      "\n",
      "bert.embeddings.word_embeddings.weight                  (21128, 768)\n",
      "bert.embeddings.position_embeddings.weight                (512, 768)\n",
      "bert.embeddings.token_type_embeddings.weight                (2, 768)\n",
      "bert.embeddings.LayerNorm.weight                              (768,)\n",
      "bert.embeddings.LayerNorm.bias                                (768,)\n",
      "\n",
      "==== First Transformer ====\n",
      "\n",
      "bert.encoder.layer.0.attention.self.query.weight          (768, 768)\n",
      "bert.encoder.layer.0.attention.self.query.bias                (768,)\n",
      "bert.encoder.layer.0.attention.self.key.weight            (768, 768)\n",
      "bert.encoder.layer.0.attention.self.key.bias                  (768,)\n",
      "bert.encoder.layer.0.attention.self.value.weight          (768, 768)\n",
      "bert.encoder.layer.0.attention.self.value.bias                (768,)\n",
      "bert.encoder.layer.0.attention.output.dense.weight        (768, 768)\n",
      "bert.encoder.layer.0.attention.output.dense.bias              (768,)\n",
      "bert.encoder.layer.0.attention.output.LayerNorm.weight        (768,)\n",
      "bert.encoder.layer.0.attention.output.LayerNorm.bias          (768,)\n",
      "bert.encoder.layer.0.intermediate.dense.weight           (3072, 768)\n",
      "bert.encoder.layer.0.intermediate.dense.bias                 (3072,)\n",
      "bert.encoder.layer.0.output.dense.weight                 (768, 3072)\n",
      "bert.encoder.layer.0.output.dense.bias                        (768,)\n",
      "bert.encoder.layer.0.output.LayerNorm.weight                  (768,)\n",
      "bert.encoder.layer.0.output.LayerNorm.bias                    (768,)\n",
      "\n",
      "==== Output Layer ====\n",
      "\n",
      "bert.pooler.dense.weight                                  (768, 768)\n",
      "bert.pooler.dense.bias                                        (768,)\n",
      "classifier.weight                                           (2, 768)\n",
      "classifier.bias                                                 (2,)\n"
     ]
    }
   ],
   "source": [
    "params = list(model.named_parameters())\n",
    "\n",
    "print('The BERT model has {:} different named parameters.\\n'.format(len(params)))\n",
    "\n",
    "print('==== Embedding Layer ====\\n')\n",
    "\n",
    "for p in params[0:5]:\n",
    "    print(\"{:<55} {:>12}\".format(p[0], str(tuple(p[1].size()))))\n",
    "\n",
    "print('\\n==== First Transformer ====\\n')\n",
    "\n",
    "for p in params[5:21]:\n",
    "    print(\"{:<55} {:>12}\".format(p[0], str(tuple(p[1].size()))))\n",
    "\n",
    "print('\\n==== Output Layer ====\\n')\n",
    "\n",
    "for p in params[-4:]:\n",
    "    print(\"{:<55} {:>12}\".format(p[0], str(tuple(p[1].size()))))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "def flat_accuracy(preds, labels):\n",
    "    pred_flat = np.argmax(preds, axis=1).flatten()\n",
    "    labels_flat = labels.flatten()\n",
    "    return np.sum(pred_flat == labels_flat) / len(labels_flat)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "import datetime\n",
    "\n",
    "def format_time(elapsed):\n",
    "    '''\n",
    "    Takes a time in seconds and returns a string hh:mm:ss\n",
    "    '''\n",
    "    elapsed_rounded = int(round((elapsed)))\n",
    "    \n",
    "    return str(datetime.timedelta(seconds=elapsed_rounded))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 训练函数\n",
    "def train():\n",
    "    model.train()\n",
    "    total_train_loss = 0\n",
    "    iter_num = 0\n",
    "    total_iter = len(train_loader)\n",
    "    for batch in train_loader:\n",
    "        # 正向传播\n",
    "        optim.zero_grad()\n",
    "        input_ids = batch['input_ids'].to(device)\n",
    "        attention_mask = batch['attention_mask'].to(device)\n",
    "        labels = batch['labels'].to(device)\n",
    "        outputs = model(input_ids, attention_mask=attention_mask, labels=labels)\n",
    "        loss = outputs[0]\n",
    "        total_train_loss += loss.item()\n",
    "        \n",
    "        # 反向梯度信息\n",
    "        loss.backward()\n",
    "        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)\n",
    "        \n",
    "        # 参数更新\n",
    "        optim.step()\n",
    "        scheduler.step()\n",
    "\n",
    "        iter_num += 1\n",
    "        if(iter_num % 100==0):\n",
    "            print(\"epoth: %d, iter_num: %d, loss: %.4f, %.2f%%\" % (epoch, iter_num, loss.item(), iter_num/total_iter*100))\n",
    "\n",
    "        \n",
    "    print(\"Epoch: %d, Average training loss: %.4f\"%(epoch, total_train_loss/len(train_loader)))\n",
    "validation_l = []\n",
    "def validation():\n",
    "    model.eval()\n",
    "    t0 = time.time()\n",
    "    total_eval_accuracy = 0\n",
    "    total_eval_loss = 0\n",
    "    for batch in test_dataloader:\n",
    "        with torch.no_grad():\n",
    "            # 正常传播\n",
    "            input_ids = batch['input_ids'].to(device)\n",
    "            attention_mask = batch['attention_mask'].to(device)\n",
    "            labels = batch['labels'].to(device)\n",
    "            outputs = model(input_ids, attention_mask=attention_mask, labels=labels)\n",
    "        \n",
    "        loss = outputs[0]\n",
    "        logits = outputs[1]\n",
    "        total_eval_loss += loss.item()\n",
    "        logits = logits.detach().cpu().numpy()\n",
    "        label_ids = labels.to('cpu').numpy()\n",
    "        total_eval_accuracy += flat_accuracy(logits, label_ids)\n",
    "        validation_l.append([outputs,logits])\n",
    "        \n",
    "    validation_time = format_time(time.time() - t0)\n",
    "    avg_val_accuracy = total_eval_accuracy / len(test_dataloader)\n",
    "    print(\"Accuracy: %.4f\" % (avg_val_accuracy))\n",
    "    print(\"Average testing loss: %.4f\"%(total_eval_loss/len(test_dataloader)))\n",
    "    print(\"Validation took: {:}\".format(validation_time))\n",
    "    print(\"-------------------------------\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "------------Epoch: 0 ----------------\n",
      "epoth: 0, iter_num: 100, loss: 0.5523, 14.33%\n",
      "epoth: 0, iter_num: 200, loss: 0.3549, 28.65%\n",
      "epoth: 0, iter_num: 300, loss: 0.3394, 42.98%\n",
      "epoth: 0, iter_num: 400, loss: 0.4074, 57.31%\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #800000; text-decoration-color: #800000\">╭─────────────────────────────── </span><span style=\"color: #800000; text-decoration-color: #800000; font-weight: bold\">Traceback </span><span style=\"color: #bf7f7f; text-decoration-color: #bf7f7f; font-weight: bold\">(most recent call last)</span><span style=\"color: #800000; text-decoration-color: #800000\"> ────────────────────────────────╮</span>\n",
       "<span style=\"color: #800000; text-decoration-color: #800000\">│</span> in <span style=\"color: #00ff00; text-decoration-color: #00ff00\">&lt;module&gt;</span>:<span style=\"color: #0000ff; text-decoration-color: #0000ff\">3</span>                                                                                    <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
       "<span style=\"color: #800000; text-decoration-color: #800000\">│</span>                                                                                                  <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
       "<span style=\"color: #800000; text-decoration-color: #800000\">│</span>   <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">1 </span><span style=\"color: #0000ff; text-decoration-color: #0000ff\">for</span> epoch <span style=\"color: #ff00ff; text-decoration-color: #ff00ff\">in</span> <span style=\"color: #00ffff; text-decoration-color: #00ffff\">range</span>(<span style=\"color: #0000ff; text-decoration-color: #0000ff\">10</span>):                                                                      <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
       "<span style=\"color: #800000; text-decoration-color: #800000\">│</span>   <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">2 </span><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">│   </span><span style=\"color: #00ffff; text-decoration-color: #00ffff\">print</span>(<span style=\"color: #808000; text-decoration-color: #808000\">\"------------Epoch: %d ----------------\"</span> % epoch)                                  <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
       "<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #800000; text-decoration-color: #800000\">❱ </span>3 <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">│   </span>train()                                                                                  <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
       "<span style=\"color: #800000; text-decoration-color: #800000\">│</span>   <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">4 </span><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">│   </span><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"># validation()</span>                                                                           <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
       "<span style=\"color: #800000; text-decoration-color: #800000\">│</span>   <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">5 </span>                                                                                             <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
       "<span style=\"color: #800000; text-decoration-color: #800000\">│</span>                                                                                                  <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
       "<span style=\"color: #800000; text-decoration-color: #800000\">│</span> in <span style=\"color: #00ff00; text-decoration-color: #00ff00\">train</span>:<span style=\"color: #0000ff; text-decoration-color: #0000ff\">18</span>                                                                                      <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
       "<span style=\"color: #800000; text-decoration-color: #800000\">│</span>                                                                                                  <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
       "<span style=\"color: #800000; text-decoration-color: #800000\">│</span>   <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">15 </span><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">│   │   </span>total_train_loss += loss.item()                                                     <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
       "<span style=\"color: #800000; text-decoration-color: #800000\">│</span>   <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">16 </span><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">│   │   </span>                                                                                    <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
       "<span style=\"color: #800000; text-decoration-color: #800000\">│</span>   <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">17 </span><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">│   │   </span><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"># 反向梯度信息</span>                                                                      <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
       "<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #800000; text-decoration-color: #800000\">❱ </span>18 <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">│   │   </span>loss.backward()                                                                     <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
       "<span style=\"color: #800000; text-decoration-color: #800000\">│</span>   <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">19 </span><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">│   │   </span>torch.nn.utils.clip_grad_norm_(model.parameters(), <span style=\"color: #0000ff; text-decoration-color: #0000ff\">1.0</span>)                             <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
       "<span style=\"color: #800000; text-decoration-color: #800000\">│</span>   <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">20 </span><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">│   │   </span>                                                                                    <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
       "<span style=\"color: #800000; text-decoration-color: #800000\">│</span>   <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">21 </span><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">│   │   </span><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"># 参数更新</span>                                                                          <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
       "<span style=\"color: #800000; text-decoration-color: #800000\">│</span>                                                                                                  <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
       "<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #808000; text-decoration-color: #808000\">d:\\conda\\envs\\newbert\\lib\\site-packages\\torch\\_tensor.py</span>:<span style=\"color: #0000ff; text-decoration-color: #0000ff\">487</span> in <span style=\"color: #00ff00; text-decoration-color: #00ff00\">backward</span>                         <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
       "<span style=\"color: #800000; text-decoration-color: #800000\">│</span>                                                                                                  <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
       "<span style=\"color: #800000; text-decoration-color: #800000\">│</span>   <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> 484 </span><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">│   │   │   │   </span>create_graph=create_graph,                                                <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
       "<span style=\"color: #800000; text-decoration-color: #800000\">│</span>   <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> 485 </span><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">│   │   │   │   </span>inputs=inputs,                                                            <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
       "<span style=\"color: #800000; text-decoration-color: #800000\">│</span>   <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> 486 </span><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">│   │   │   </span>)                                                                             <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
       "<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #800000; text-decoration-color: #800000\">❱ </span> 487 <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">│   │   </span>torch.autograd.backward(                                                          <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
       "<span style=\"color: #800000; text-decoration-color: #800000\">│</span>   <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> 488 </span><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">│   │   │   </span><span style=\"color: #00ffff; text-decoration-color: #00ffff\">self</span>, gradient, retain_graph, create_graph, inputs=inputs                     <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
       "<span style=\"color: #800000; text-decoration-color: #800000\">│</span>   <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> 489 </span><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">│   │   </span>)                                                                                 <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
       "<span style=\"color: #800000; text-decoration-color: #800000\">│</span>   <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> 490 </span>                                                                                          <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
       "<span style=\"color: #800000; text-decoration-color: #800000\">│</span>                                                                                                  <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
       "<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #808000; text-decoration-color: #808000\">d:\\conda\\envs\\newbert\\lib\\site-packages\\torch\\autograd\\__init__.py</span>:<span style=\"color: #0000ff; text-decoration-color: #0000ff\">200</span> in <span style=\"color: #00ff00; text-decoration-color: #00ff00\">backward</span>               <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
       "<span style=\"color: #800000; text-decoration-color: #800000\">│</span>                                                                                                  <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
       "<span style=\"color: #800000; text-decoration-color: #800000\">│</span>   <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">197 </span><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">│   </span><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"># The reason we repeat same the comment below is that</span>                                  <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
       "<span style=\"color: #800000; text-decoration-color: #800000\">│</span>   <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">198 </span><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">│   </span><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"># some Python versions print out the first line of a multi-line function</span>               <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
       "<span style=\"color: #800000; text-decoration-color: #800000\">│</span>   <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">199 </span><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">│   </span><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"># calls in the traceback and some print out the last line</span>                              <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
       "<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #800000; text-decoration-color: #800000\">❱ </span>200 <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">│   </span>Variable._execution_engine.run_backward(  <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"># Calls into the C++ engine to run the bac</span>   <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
       "<span style=\"color: #800000; text-decoration-color: #800000\">│</span>   <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">201 </span><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">│   │   </span>tensors, grad_tensors_, retain_graph, create_graph, inputs,                        <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
       "<span style=\"color: #800000; text-decoration-color: #800000\">│</span>   <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">202 </span><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">│   │   </span>allow_unreachable=<span style=\"color: #0000ff; text-decoration-color: #0000ff\">True</span>, accumulate_grad=<span style=\"color: #0000ff; text-decoration-color: #0000ff\">True</span>)  <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"># Calls into the C++ engine to ru</span>   <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
       "<span style=\"color: #800000; text-decoration-color: #800000\">│</span>   <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">203 </span>                                                                                           <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
       "<span style=\"color: #800000; text-decoration-color: #800000\">╰──────────────────────────────────────────────────────────────────────────────────────────────────╯</span>\n",
       "<span style=\"color: #ff0000; text-decoration-color: #ff0000; font-weight: bold\">KeyboardInterrupt</span>\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[31m╭─\u001b[0m\u001b[31m──────────────────────────────\u001b[0m\u001b[31m \u001b[0m\u001b[1;31mTraceback \u001b[0m\u001b[1;2;31m(most recent call last)\u001b[0m\u001b[31m \u001b[0m\u001b[31m───────────────────────────────\u001b[0m\u001b[31m─╮\u001b[0m\n",
       "\u001b[31m│\u001b[0m in \u001b[92m<module>\u001b[0m:\u001b[94m3\u001b[0m                                                                                    \u001b[31m│\u001b[0m\n",
       "\u001b[31m│\u001b[0m                                                                                                  \u001b[31m│\u001b[0m\n",
       "\u001b[31m│\u001b[0m   \u001b[2m1 \u001b[0m\u001b[94mfor\u001b[0m epoch \u001b[95min\u001b[0m \u001b[96mrange\u001b[0m(\u001b[94m10\u001b[0m):                                                                      \u001b[31m│\u001b[0m\n",
       "\u001b[31m│\u001b[0m   \u001b[2m2 \u001b[0m\u001b[2m│   \u001b[0m\u001b[96mprint\u001b[0m(\u001b[33m\"\u001b[0m\u001b[33m------------Epoch: \u001b[0m\u001b[33m%d\u001b[0m\u001b[33m ----------------\u001b[0m\u001b[33m\"\u001b[0m % epoch)                                  \u001b[31m│\u001b[0m\n",
       "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m3 \u001b[2m│   \u001b[0mtrain()                                                                                  \u001b[31m│\u001b[0m\n",
       "\u001b[31m│\u001b[0m   \u001b[2m4 \u001b[0m\u001b[2m│   \u001b[0m\u001b[2m# validation()\u001b[0m                                                                           \u001b[31m│\u001b[0m\n",
       "\u001b[31m│\u001b[0m   \u001b[2m5 \u001b[0m                                                                                             \u001b[31m│\u001b[0m\n",
       "\u001b[31m│\u001b[0m                                                                                                  \u001b[31m│\u001b[0m\n",
       "\u001b[31m│\u001b[0m in \u001b[92mtrain\u001b[0m:\u001b[94m18\u001b[0m                                                                                      \u001b[31m│\u001b[0m\n",
       "\u001b[31m│\u001b[0m                                                                                                  \u001b[31m│\u001b[0m\n",
       "\u001b[31m│\u001b[0m   \u001b[2m15 \u001b[0m\u001b[2m│   │   \u001b[0mtotal_train_loss += loss.item()                                                     \u001b[31m│\u001b[0m\n",
       "\u001b[31m│\u001b[0m   \u001b[2m16 \u001b[0m\u001b[2m│   │   \u001b[0m                                                                                    \u001b[31m│\u001b[0m\n",
       "\u001b[31m│\u001b[0m   \u001b[2m17 \u001b[0m\u001b[2m│   │   \u001b[0m\u001b[2m# 反向梯度信息\u001b[0m                                                                      \u001b[31m│\u001b[0m\n",
       "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m18 \u001b[2m│   │   \u001b[0mloss.backward()                                                                     \u001b[31m│\u001b[0m\n",
       "\u001b[31m│\u001b[0m   \u001b[2m19 \u001b[0m\u001b[2m│   │   \u001b[0mtorch.nn.utils.clip_grad_norm_(model.parameters(), \u001b[94m1.0\u001b[0m)                             \u001b[31m│\u001b[0m\n",
       "\u001b[31m│\u001b[0m   \u001b[2m20 \u001b[0m\u001b[2m│   │   \u001b[0m                                                                                    \u001b[31m│\u001b[0m\n",
       "\u001b[31m│\u001b[0m   \u001b[2m21 \u001b[0m\u001b[2m│   │   \u001b[0m\u001b[2m# 参数更新\u001b[0m                                                                          \u001b[31m│\u001b[0m\n",
       "\u001b[31m│\u001b[0m                                                                                                  \u001b[31m│\u001b[0m\n",
       "\u001b[31m│\u001b[0m \u001b[33md:\\conda\\envs\\newbert\\lib\\site-packages\\torch\\_tensor.py\u001b[0m:\u001b[94m487\u001b[0m in \u001b[92mbackward\u001b[0m                         \u001b[31m│\u001b[0m\n",
       "\u001b[31m│\u001b[0m                                                                                                  \u001b[31m│\u001b[0m\n",
       "\u001b[31m│\u001b[0m   \u001b[2m 484 \u001b[0m\u001b[2m│   │   │   │   \u001b[0mcreate_graph=create_graph,                                                \u001b[31m│\u001b[0m\n",
       "\u001b[31m│\u001b[0m   \u001b[2m 485 \u001b[0m\u001b[2m│   │   │   │   \u001b[0minputs=inputs,                                                            \u001b[31m│\u001b[0m\n",
       "\u001b[31m│\u001b[0m   \u001b[2m 486 \u001b[0m\u001b[2m│   │   │   \u001b[0m)                                                                             \u001b[31m│\u001b[0m\n",
       "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m 487 \u001b[2m│   │   \u001b[0mtorch.autograd.backward(                                                          \u001b[31m│\u001b[0m\n",
       "\u001b[31m│\u001b[0m   \u001b[2m 488 \u001b[0m\u001b[2m│   │   │   \u001b[0m\u001b[96mself\u001b[0m, gradient, retain_graph, create_graph, inputs=inputs                     \u001b[31m│\u001b[0m\n",
       "\u001b[31m│\u001b[0m   \u001b[2m 489 \u001b[0m\u001b[2m│   │   \u001b[0m)                                                                                 \u001b[31m│\u001b[0m\n",
       "\u001b[31m│\u001b[0m   \u001b[2m 490 \u001b[0m                                                                                          \u001b[31m│\u001b[0m\n",
       "\u001b[31m│\u001b[0m                                                                                                  \u001b[31m│\u001b[0m\n",
       "\u001b[31m│\u001b[0m \u001b[33md:\\conda\\envs\\newbert\\lib\\site-packages\\torch\\autograd\\__init__.py\u001b[0m:\u001b[94m200\u001b[0m in \u001b[92mbackward\u001b[0m               \u001b[31m│\u001b[0m\n",
       "\u001b[31m│\u001b[0m                                                                                                  \u001b[31m│\u001b[0m\n",
       "\u001b[31m│\u001b[0m   \u001b[2m197 \u001b[0m\u001b[2m│   \u001b[0m\u001b[2m# The reason we repeat same the comment below is that\u001b[0m                                  \u001b[31m│\u001b[0m\n",
       "\u001b[31m│\u001b[0m   \u001b[2m198 \u001b[0m\u001b[2m│   \u001b[0m\u001b[2m# some Python versions print out the first line of a multi-line function\u001b[0m               \u001b[31m│\u001b[0m\n",
       "\u001b[31m│\u001b[0m   \u001b[2m199 \u001b[0m\u001b[2m│   \u001b[0m\u001b[2m# calls in the traceback and some print out the last line\u001b[0m                              \u001b[31m│\u001b[0m\n",
       "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m200 \u001b[2m│   \u001b[0mVariable._execution_engine.run_backward(  \u001b[2m# Calls into the C++ engine to run the bac\u001b[0m   \u001b[31m│\u001b[0m\n",
       "\u001b[31m│\u001b[0m   \u001b[2m201 \u001b[0m\u001b[2m│   │   \u001b[0mtensors, grad_tensors_, retain_graph, create_graph, inputs,                        \u001b[31m│\u001b[0m\n",
       "\u001b[31m│\u001b[0m   \u001b[2m202 \u001b[0m\u001b[2m│   │   \u001b[0mallow_unreachable=\u001b[94mTrue\u001b[0m, accumulate_grad=\u001b[94mTrue\u001b[0m)  \u001b[2m# Calls into the C++ engine to ru\u001b[0m   \u001b[31m│\u001b[0m\n",
       "\u001b[31m│\u001b[0m   \u001b[2m203 \u001b[0m                                                                                           \u001b[31m│\u001b[0m\n",
       "\u001b[31m╰──────────────────────────────────────────────────────────────────────────────────────────────────╯\u001b[0m\n",
       "\u001b[1;91mKeyboardInterrupt\u001b[0m\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Training driver: calls train() ten times in a row, printing a banner with the\n",
    "# zero-based epoch index before each pass.\n",
    "for epoch in range(10):\n",
    "    print(\"------------Epoch: %d ----------------\" % epoch)\n",
    "    train()\n",
    "    # NOTE(review): the validation() call below is commented out, so this loop only\n",
    "    # trains and never evaluates between epochs -- confirm that is intentional.\n",
    "    # validation()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write the model object to ./sec_bert_new.ckpt (path is relative to the working directory).\n",
    "# NOTE(review): this passes the whole model object to torch.save rather than\n",
    "# model.state_dict(); presumably whatever loads this file relies on that -- confirm\n",
    "# before changing the checkpoint format.\n",
    "torch.save(model, './sec_bert_new.ckpt')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "newbert",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
